{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "711752cb-4f15-42a3-9838-a0c67f397771",
   "metadata": {},
   "source": [
    "# Bind runtime args\n",
    "\n",
    "Sometimes we want to invoke a Runnable within a Runnable sequence with constant arguments that are neither part of the preceding Runnable's output nor part of the user input. `Runnable.bind()` handles this: it returns a new Runnable that passes the given arguments on every invocation.\n",
    "\n",
    "Suppose we have a simple prompt + model sequence:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c5dad8b5",
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install --upgrade --quiet langchain langchain-openai"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "950297ed-2d67-4091-8ea7-1d412d259d04",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_core.output_parsers import StrOutputParser\n",
    "from langchain_core.prompts import ChatPromptTemplate\n",
    "from langchain_core.runnables import RunnablePassthrough\n",
    "from langchain_openai import ChatOpenAI"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "f3fdf86d-155f-4587-b7cd-52d363970c1d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "EQUATION: x^3 + 7 = 12\n",
      "\n",
      "SOLUTION:\n",
      "Subtracting 7 from both sides of the equation, we get:\n",
      "x^3 = 12 - 7\n",
      "x^3 = 5\n",
      "\n",
      "Taking the cube root of both sides, we get:\n",
      "x = ∛5\n",
      "\n",
      "Therefore, the solution to the equation x^3 + 7 = 12 is x = ∛5.\n"
     ]
    }
   ],
   "source": [
    "prompt = ChatPromptTemplate.from_messages(\n",
    "    [\n",
    "        (\n",
    "            \"system\",\n",
    "            \"Write out the following equation using algebraic symbols then solve it. Use the format\\n\\nEQUATION:...\\nSOLUTION:...\\n\\n\",\n",
    "        ),\n",
    "        (\"human\", \"{equation_statement}\"),\n",
    "    ]\n",
    ")\n",
    "model = ChatOpenAI(temperature=0)\n",
    "runnable = (\n",
    "    {\"equation_statement\": RunnablePassthrough()} | prompt | model | StrOutputParser()\n",
    ")\n",
    "\n",
    "print(runnable.invoke(\"x raised to the third plus seven equals 12\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "929c9aba-a4a0-462c-adac-2cfc2156e117",
   "metadata": {},
   "source": [
    "Now suppose we want to call the model with certain `stop` words, so that generation halts before the solution is written:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "32e0484a-78c5-4570-a00b-20d597245a96",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "EQUATION: x^3 + 7 = 12\n",
      "\n",
      "\n"
     ]
    }
   ],
   "source": [
    "runnable = (\n",
    "    {\"equation_statement\": RunnablePassthrough()}\n",
    "    | prompt\n",
    "    | model.bind(stop=[\"SOLUTION\"])\n",
    "    | StrOutputParser()\n",
    ")\n",
    "print(runnable.invoke(\"x raised to the third plus seven equals 12\"))"
   ]
  },
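  {
   "cell_type": "markdown",
   "id": "bind-sketch-intro",
   "metadata": {},
   "source": [
    "To make the mechanics concrete, here is a minimal pure-Python sketch of what binding does conceptually. `BoundCallable` and `fake_model` are hypothetical names invented for this illustration, and this is an analogy rather than LangChain's implementation: `bind()` returns a new object that remembers the keyword arguments and forwards them on every call, leaving the original untouched."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bind-sketch-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Any, Callable, List, Optional\n",
    "\n",
    "\n",
    "class BoundCallable:\n",
    "    \"\"\"Toy stand-in (hypothetical, not a LangChain class) for a bound runnable.\"\"\"\n",
    "\n",
    "    def __init__(self, func: Callable[..., Any], **bound_kwargs: Any) -> None:\n",
    "        self.func = func\n",
    "        self.bound_kwargs = bound_kwargs\n",
    "\n",
    "    def bind(self, **kwargs: Any) -> \"BoundCallable\":\n",
    "        # Return a new object with the merged kwargs; the original is unchanged.\n",
    "        return BoundCallable(self.func, **{**self.bound_kwargs, **kwargs})\n",
    "\n",
    "    def invoke(self, value: Any) -> Any:\n",
    "        # Forward the input plus every bound keyword argument.\n",
    "        return self.func(value, **self.bound_kwargs)\n",
    "\n",
    "\n",
    "def fake_model(text: str, stop: Optional[List[str]] = None) -> str:\n",
    "    \"\"\"Hypothetical model: echoes its input, cut off at the first stop sequence.\"\"\"\n",
    "    for sequence in stop or []:\n",
    "        text = text.split(sequence)[0]\n",
    "    return text\n",
    "\n",
    "\n",
    "model = BoundCallable(fake_model)\n",
    "bound = model.bind(stop=[\"SOLUTION\"])\n",
    "\n",
    "text = \"EQUATION: x^3 + 7 = 12\\nSOLUTION: x = ∛5\"\n",
    "full = model.invoke(text)  # no stop sequence: the text comes back whole\n",
    "cut = bound.invoke(text)  # stop sequence applied: the text ends before \"SOLUTION\"\n",
    "print(repr(cut))"
   ]
  },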
  {
   "cell_type": "markdown",
   "id": "f4bd641f-6b58-4ca9-a544-f69095428f16",
   "metadata": {},
   "source": [
    "## Attaching OpenAI functions\n",
    "\n",
    "One particularly useful application of binding is to attach OpenAI functions to a compatible OpenAI model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "f66a0fe4-fde0-4706-8863-d60253f211c7",
   "metadata": {},
   "outputs": [],
   "source": [
    "function = {\n",
    "    \"name\": \"solver\",\n",
    "    \"description\": \"Formulates and solves an equation\",\n",
    "    \"parameters\": {\n",
    "        \"type\": \"object\",\n",
    "        \"properties\": {\n",
    "            \"equation\": {\n",
    "                \"type\": \"string\",\n",
    "                \"description\": \"The algebraic expression of the equation\",\n",
    "            },\n",
    "            \"solution\": {\n",
    "                \"type\": \"string\",\n",
    "                \"description\": \"The solution to the equation\",\n",
    "            },\n",
    "        },\n",
    "        \"required\": [\"equation\", \"solution\"],\n",
    "    },\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "f381f969-df8e-48a3-bf5c-d0397cfecde0",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "AIMessage(content='', additional_kwargs={'function_call': {'name': 'solver', 'arguments': '{\\n\"equation\": \"x^3 + 7 = 12\",\\n\"solution\": \"x = ∛5\"\\n}'}}, example=False)"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Need gpt-4 to solve this one correctly\n",
    "prompt = ChatPromptTemplate.from_messages(\n",
    "    [\n",
    "        (\n",
    "            \"system\",\n",
    "            \"Write out the following equation using algebraic symbols then solve it.\",\n",
    "        ),\n",
    "        (\"human\", \"{equation_statement}\"),\n",
    "    ]\n",
    ")\n",
    "model = ChatOpenAI(model=\"gpt-4\", temperature=0).bind(\n",
    "    function_call={\"name\": \"solver\"}, functions=[function]\n",
    ")\n",
    "runnable = {\"equation_statement\": RunnablePassthrough()} | prompt | model\n",
    "runnable.invoke(\"x raised to the third plus seven equals 12\")"
   ]
  },
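  {
   "cell_type": "markdown",
   "id": "function-call-parse-intro",
   "metadata": {},
   "source": [
    "The function arguments come back as a JSON string inside `additional_kwargs`, so they need decoding before use. A minimal sketch, assuming the response has the shape shown in the output above: the dictionary below is copied from that output rather than produced by a live call, and the variable names are illustrative."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "function-call-parse-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "# Copied from the `additional_kwargs` of the AIMessage shown above (assumed shape).\n",
    "additional_kwargs = {\n",
    "    \"function_call\": {\n",
    "        \"name\": \"solver\",\n",
    "        \"arguments\": '{\\n\"equation\": \"x^3 + 7 = 12\",\\n\"solution\": \"x = ∛5\"\\n}',\n",
    "    }\n",
    "}\n",
    "\n",
    "function_call = additional_kwargs[\"function_call\"]\n",
    "# The arguments are a JSON-encoded string, so decode them into a dict.\n",
    "arguments = json.loads(function_call[\"arguments\"])\n",
    "\n",
    "print(function_call[\"name\"])\n",
    "print(arguments[\"equation\"])\n",
    "print(arguments[\"solution\"])"
   ]
  },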
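  {
   "cell_type": "markdown",
   "id": "f07d7528-9269-4d6f-b12e-3669592a9e03",
   "metadata": {},
   "source": [
    "## Attaching OpenAI tools\n",
    "\n",
    "Tools are attached the same way, by binding a `tools` list to a compatible OpenAI model:"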
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "2cdeeb4c-0c1f-43da-bd58-4f591d9e0671",
   "metadata": {},
   "outputs": [],
   "source": [
    "tools = [\n",
    "    {\n",
    "        \"type\": \"function\",\n",
    "        \"function\": {\n",
    "            \"name\": \"get_current_weather\",\n",
    "            \"description\": \"Get the current weather in a given location\",\n",
    "            \"parameters\": {\n",
    "                \"type\": \"object\",\n",
    "                \"properties\": {\n",
    "                    \"location\": {\n",
    "                        \"type\": \"string\",\n",
    "                        \"description\": \"The city and state, e.g. San Francisco, CA\",\n",
    "                    },\n",
    "                    \"unit\": {\"type\": \"string\", \"enum\": [\"celsius\", \"fahrenheit\"]},\n",
    "                },\n",
    "                \"required\": [\"location\"],\n",
    "            },\n",
    "        },\n",
    "    }\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "2b65beab-48bb-46ff-a5a4-ef8ac95a513c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_zHN0ZHwrxM7nZDdqTp6dkPko', 'function': {'arguments': '{\"location\": \"San Francisco, CA\", \"unit\": \"celsius\"}', 'name': 'get_current_weather'}, 'type': 'function'}, {'id': 'call_aqdMm9HBSlFW9c9rqxTa7eQv', 'function': {'arguments': '{\"location\": \"New York, NY\", \"unit\": \"celsius\"}', 'name': 'get_current_weather'}, 'type': 'function'}, {'id': 'call_cx8E567zcLzYV2WSWVgO63f1', 'function': {'arguments': '{\"location\": \"Los Angeles, CA\", \"unit\": \"celsius\"}', 'name': 'get_current_weather'}, 'type': 'function'}]})"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model = ChatOpenAI(model=\"gpt-3.5-turbo-1106\").bind(tools=tools)\n",
    "model.invoke(\"What's the weather in SF, NYC and LA?\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "poetry-venv",
   "language": "python",
   "name": "poetry-venv"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
